# Leveraging KV Similarity for Online Structured Pruning in LLMs

This repository contains an implementation of *Leveraging KV Similarity for Online Structured Pruning in LLMs*.

## Installation

This code is based on the following library versions:

- transformers==4.54.0
- lm-eval==0.4.9

- Step 1: Create a new conda environment:
  ```bash
  conda create -n token_filtering python=3.9
  conda activate token_filtering
  ```

- Step 2: Install the required packages:
  ```bash
  git clone https://github.com/EleutherAI/lm-evaluation-harness.git
  cd lm-evaluation-harness
  pip install -r requirements.txt
  pip install transformers==4.54.0
  ```
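
Because the modified model files target these exact versions, it can help to confirm what is installed before running experiments. The snippet below is a minimal sketch that uses only the standard library (`importlib.metadata`). The helper name `check_versions` is hypothetical, and the sketch assumes the packages are registered under the distribution names `transformers` and `lm_eval`.

```python
from importlib.metadata import PackageNotFoundError, version

# Versions listed in this README (distribution names as registered by pip).
EXPECTED_VERSIONS = {"transformers": "4.54.0", "lm_eval": "0.4.9"}


def check_versions(expected: dict = EXPECTED_VERSIONS) -> dict:
    """Return {package: (expected, installed)} for every mismatch; None means not installed."""
    mismatches = {}
    for package, wanted in expected.items():
        try:
            installed = version(package)
        except PackageNotFoundError:
            installed = None  # package is missing from this environment
        if installed != wanted:
            mismatches[package] = (wanted, installed)
    return mismatches


print(check_versions() or "all versions match")
```

An empty result means both versions match; otherwise each entry maps a package to `(expected, installed)`.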

## Modified Code
Our experiments require modifications to Hugging Face's LLaMA, Mistral, and Phi-3 implementations.
The modified files are:

- `transformers/models/llama/modeling_llama.py`
- `transformers/models/mistral/modeling_mistral.py`
- `transformers/models/phi3/modeling_phi3.py`

How to apply:

1. Install the specified libraries.
2. Replace each original file in your installed `transformers` package with the provided file at the same relative path (e.g., `transformers/models/llama/modeling_llama.py`).
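
Copying the files by hand is sufficient. As an alternative, the sketch below automates the replacement and keeps a `.orig` backup of each original file. It assumes the provided files sit in a local `transformers/` folder that mirrors the package layout, as the paths above suggest; `installed_transformers_dir` and `apply_patches` are hypothetical helpers, not part of this repository.

```python
import importlib.util
import shutil
from pathlib import Path

# Paths of the modified files, relative to the transformers package root.
MODIFIED_FILES = [
    "models/llama/modeling_llama.py",
    "models/mistral/modeling_mistral.py",
    "models/phi3/modeling_phi3.py",
]


def installed_transformers_dir() -> Path:
    """Locate the installed transformers package without importing it."""
    spec = importlib.util.find_spec("transformers")
    if spec is None or spec.origin is None:
        raise RuntimeError("transformers is not installed in this environment")
    return Path(spec.origin).parent


def apply_patches(patched_root, target_root, files=MODIFIED_FILES) -> list:
    """Copy each modified file over the installed one, backing up the original once."""
    patched_root, target_root = Path(patched_root), Path(target_root)
    replaced = []
    for rel in files:
        src = patched_root / rel
        dst = target_root / rel
        if not src.is_file():
            raise FileNotFoundError(f"modified file not found: {src}")
        backup = dst.with_name(dst.name + ".orig")
        if dst.exists() and not backup.exists():
            shutil.copy2(dst, backup)  # keep the untouched original
        dst.parent.mkdir(parents=True, exist_ok=True)
        shutil.copy2(src, dst)
        replaced.append(dst)
    return replaced
```

Usage: `apply_patches(Path("transformers"), installed_transformers_dir())`. Since `installed_transformers_dir()` relies on import resolution, print its result first and check that it points into your environment's `site-packages`.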

## Model Download
We recommend downloading the model weights locally before running experiments.
A helper script, `download.py`, is provided for convenience:
  ```bash
  python download.py
  ```
To switch to a different model, change the model name inside `download.py` and run it again.
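
The repository's `download.py` is not reproduced here. As an illustration of the same idea, the sketch below assumes the `huggingface_hub` package and its `snapshot_download(repo_id=..., local_dir=...)` function; `MODEL_NAME`, `local_model_dir`, `download`, and the `models/` folder are hypothetical choices.

```python
from pathlib import Path

# Hypothetical model choice; edit this string to switch models.
MODEL_NAME = "meta-llama/Llama-2-13b-hf"


def local_model_dir(model_name: str, root: str = "models") -> Path:
    """Map a Hub repo id such as 'org/name' to the local folder root/name."""
    return Path(root) / model_name.split("/")[-1]


def download(model_name: str = MODEL_NAME, root: str = "models") -> Path:
    """Download every file of a Hub repo into a local folder (needs network access)."""
    from huggingface_hub import snapshot_download  # imported lazily

    target = local_model_dir(model_name, root)
    snapshot_download(repo_id=model_name, local_dir=str(target))
    return target
```

Calling `download()` fetches the files into `models/Llama-2-13b-hf`. Gated repositories such as LLaMA-2 also require accepting the license on the Hub and authenticating with an access token.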


## Hyperparameter Settings (Pruning)
The pruning ratio and the number of warmup tokens are set in `hyper.py`:
  ```python
  target_skip_ratio = 0.33   # Fraction of layers to prune (e.g., 0.33 → ~33% pruning)
  warmup_tokens     = 1      # Number of initial tokens computed without pruning
  ```

- `target_skip_ratio`

  Fraction of layers that are skipped (pruned) during forward passes. Examples:

  - `0.0` → no pruning (baseline)
  - `0.25` → skip ~25% of layers
  - `0.5` → skip ~50% of layers

- `warmup_tokens`

  Number of tokens at the beginning of generation that are computed with all layers (no pruning). Examples:

  - `1` → only the first token uses all layers; pruning starts from the 2nd token
  - `5` → the first 5 tokens are computed with all layers, then pruning is applied
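
To make the two settings concrete, the sketch below shows one plausible reading of them: the number of skipped layers is the ratio times the layer count, rounded to the nearest integer, and each token is checked against the warmup length. The helpers `num_skipped_layers` and `pruning_active` are hypothetical and are not taken from this repository, whose actual choice of what to skip may differ.

```python
# Hypothetical helpers illustrating the settings in hyper.py.
target_skip_ratio = 0.33
warmup_tokens = 1


def num_skipped_layers(skip_ratio: float, num_layers: int) -> int:
    """Approximate number of layers skipped per forward pass (rounded to nearest)."""
    if not 0.0 <= skip_ratio < 1.0:
        raise ValueError("skip_ratio must be in [0, 1)")
    return round(skip_ratio * num_layers)


def pruning_active(token_index: int, warmup: int) -> bool:
    """True if the token at 0-based position token_index is past the warmup."""
    return token_index >= warmup


# LLaMA-2-13B has 40 decoder layers.
print(num_skipped_layers(target_skip_ratio, 40))  # 13
print([pruning_active(i, warmup_tokens) for i in range(3)])  # [False, True, True]
```

Under this reading, the default settings on LLaMA-2-13B skip about 13 of 40 layers per token, and only the first token runs the full model.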


## Examples

- Evaluate LLaMA-2-13B on the commonsense reasoning benchmarks:

  ```bash
  cd tf_scripts
  ./llama2_13b_acc.sh
  ```

- Measure the perplexity of LLaMA-2-13B:

  ```bash
  cd ppl_test
  python3 ppl_test_llama.py
  ```

- Measure the latency of LLaMA-2-13B:
  ```bash
  cd ppl_test
  python3 latency_test_llama.py
  ```
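
`ppl_test_llama.py` reports perplexity. For reference, the sketch below states the standard definition (not the script's code): the exponential of the mean negative log-likelihood of the scored tokens. With a Hugging Face causal LM, the `loss` returned when `labels` are passed is this mean cross-entropy, so `math.exp(loss)` gives the same quantity. The function name `perplexity` is illustrative.

```python
import math
from typing import Sequence


def perplexity(token_log_probs: Sequence[float]) -> float:
    """exp(-mean natural-log probability) over the scored tokens."""
    if not token_log_probs:
        raise ValueError("need at least one token log-probability")
    mean_nll = -sum(token_log_probs) / len(token_log_probs)
    return math.exp(mean_nll)


# Every token predicted with probability 0.25 gives a perplexity of about 4.
print(perplexity([math.log(0.25)] * 10))  # ≈ 4.0
```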

## Results

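*Zero-shot performance of LLaMA-2-13B after pruning attention and MLP blocks without fine-tuning. Text generation is measured by perplexity (lower is better) and commonsense reasoning by accuracy (higher is better). Among the pruned models, PP attains the lowest perplexity at both pruning ratios, while Token Filtering achieves the highest commonsense reasoning accuracy.*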

| Method                | Pruning Ratio | LLaMA-2-13B (Text Generation) ↓ | LLaMA-2-13B (Commonsense Reasoning) ↑ | 
| ----------------------| ------------- | ------------------------------- | ------------------------------------- | 
| **Dense**             | 0%            | 10.98                           | 69.51                                 | 
| **SlimGPT w/o**       | 20%           | 13.80                           | 67.43                                 | 
| **FLAP**              | 20%           | 14.13                           | 66.89                                 | 
| **PP**                | 20%           | **12.52**                       | 66.33                                 | 
| **Token Filtering**   | 20%           | 13.37                           | **68.69**                             | 
| **SlimGPT w/o**       | 50%           | 32.67                           | 57.15                                 | 
| **FLAP**              | 50%           | 29.45                           | 57.41                                 | 
| **PP**                | 50%           | **28.86**                       | 51.93                                 | 
| **Token Filtering**   | 50%           | 29.22                           | **65.90**                             | 

